{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Decision Tree Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.sql import SparkSession\n",
    "\n",
    "spark = SparkSession \\\n",
    "    .builder \\\n",
    "    .appName(\"Python Spark Random Forest Regression\") \\\n",
    "    .config(\"spark.some.config.option\", \"some-value\") \\\n",
    "    .getOrCreate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "df = spark.read.format('com.databricks.spark.csv').\\\n",
    "                               options(header='true', \\\n",
    "                               inferschema='true').load(\"./data/WineData.csv\",header=True);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.linalg import Vectors # !!!!caution: not from pyspark.mllib.linalg import Vectors\n",
    "from pyspark.ml import Pipeline\n",
    "from pyspark.ml.feature import IndexToString,StringIndexer, VectorIndexer\n",
    "from pyspark.ml.classification import DecisionTreeClassifier\n",
    "from pyspark.ml.tuning import CrossValidator, ParamGridBuilder\n",
    "from pyspark.ml.evaluation import MulticlassClassificationEvaluator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def transData(data):\n",
    "    return data.rdd.map(lambda r: [Vectors.dense(r[:-1]),r[-1]]).toDF(['features','label'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "data = transData(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------------+-----+\n",
      "|            features|label|\n",
      "+--------------------+-----+\n",
      "|[7.4,0.7,0.0,1.9,...|    5|\n",
      "|[7.8,0.88,0.0,2.6...|    5|\n",
      "|[7.8,0.76,0.04,2....|    5|\n",
      "+--------------------+-----+\n",
      "only showing top 3 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "data.show(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------------+-----+------------+\n",
      "|            features|label|indexedLabel|\n",
      "+--------------------+-----+------------+\n",
      "|[7.4,0.7,0.0,1.9,...|    5|         0.0|\n",
      "|[7.8,0.88,0.0,2.6...|    5|         0.0|\n",
      "|[7.8,0.76,0.04,2....|    5|         0.0|\n",
      "|[11.2,0.28,0.56,1...|    6|         1.0|\n",
      "|[7.4,0.7,0.0,1.9,...|    5|         0.0|\n",
      "|[7.4,0.66,0.0,1.8...|    5|         0.0|\n",
      "+--------------------+-----+------------+\n",
      "only showing top 6 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Index labels, adding metadata to the label column\n",
    "labelIndexer = StringIndexer(inputCol='label',\n",
    "                             outputCol='indexedLabel').fit(data)\n",
    "labelIndexer.transform(data).show(6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------------+-----+--------------------+\n",
      "|            features|label|     indexedFeatures|\n",
      "+--------------------+-----+--------------------+\n",
      "|[7.4,0.7,0.0,1.9,...|    5|[7.4,0.7,0.0,1.9,...|\n",
      "|[7.8,0.88,0.0,2.6...|    5|[7.8,0.88,0.0,2.6...|\n",
      "|[7.8,0.76,0.04,2....|    5|[7.8,0.76,0.04,2....|\n",
      "|[11.2,0.28,0.56,1...|    6|[11.2,0.28,0.56,1...|\n",
      "|[7.4,0.7,0.0,1.9,...|    5|[7.4,0.7,0.0,1.9,...|\n",
      "|[7.4,0.66,0.0,1.8...|    5|[7.4,0.66,0.0,1.8...|\n",
      "+--------------------+-----+--------------------+\n",
      "only showing top 6 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Automatically identify categorical features, and index them.\n",
    "# Set maxCategories so features with > 4 distinct values are treated as continuous.\n",
    "featureIndexer =VectorIndexer(inputCol=\"features\", \\\n",
    "                              outputCol=\"indexedFeatures\", \\\n",
    "                              maxCategories=4).fit(data)\n",
    "\n",
    "featureIndexer.transform(data).show(6)   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Train a DecisionTree model\n",
    "dTree = DecisionTreeClassifier(labelCol='indexedLabel', featuresCol='indexedFeatures')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Convert indexed labels back to original labels.\n",
    "labelConverter = IndexToString(inputCol=\"prediction\", outputCol=\"predictedLabel\",\n",
    "                               labels=labelIndexer.labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Chain indexers and tree in a Pipeline\n",
    "pipeline = Pipeline(stages=[labelIndexer, featureIndexer, dTree,labelConverter])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Split the data into training and test sets (30% held out for testing)\n",
    "(trainingData, testData) = data.randomSplit([0.7, 0.3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Train model.  This also runs the indexers.\n",
    "model = pipeline.fit(trainingData)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Make predictions.\n",
    "predictions = model.transform(testData)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------------+-----+--------------+\n",
      "|            features|label|predictedLabel|\n",
      "+--------------------+-----+--------------+\n",
      "|[5.0,0.38,0.01,1....|    6|             6|\n",
      "|[5.0,0.42,0.24,2....|    8|             6|\n",
      "|[5.0,0.74,0.0,1.2...|    6|             6|\n",
      "|[5.0,1.04,0.24,1....|    5|             6|\n",
      "|[5.1,0.42,0.0,1.8...|    7|             6|\n",
      "+--------------------+-----+--------------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Select example rows to display.\n",
    "predictions.select(\"features\",\"label\",\"predictedLabel\").show(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predictions accuracy = 0.532, Test Error = 0.468\n"
     ]
    }
   ],
   "source": [
    "# Select (prediction, true label) and compute test error\n",
    "evaluator = MulticlassClassificationEvaluator(\n",
    "    labelCol=\"indexedLabel\", predictionCol=\"prediction\", metricName=\"accuracy\")\n",
    "accuracy = evaluator.evaluate(predictions)\n",
    "print(\"Predictions accuracy = %g, Test Error = %g\" % (accuracy,(1.0 - accuracy)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Cross-Validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Search through decision tree's maxDepth parameter for best model\n",
    "paramGrid = ParamGridBuilder().addGrid(dTree.maxDepth, [2,3,4,5,6,7]).build()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Set F-1 score as evaluation metric for best model selection\n",
    "evaluator = MulticlassClassificationEvaluator(labelCol='indexedLabel',\n",
    "                                              predictionCol='prediction', metricName='f1')    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Set up 3-fold cross validation\n",
    "crossval = CrossValidator(estimator=pipeline,\n",
    "                          estimatorParamMaps=paramGrid,\n",
    "                          evaluator=evaluator,\n",
    "                          numFolds=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Train model.  This also runs the indexers.\n",
    "dTree_model = crossval.fit(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Fetch best model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PipelineModel_42999d14f2640615d800"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fetch best model\n",
    "tree_model = dTree_model.bestModel\n",
    "tree_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Make predictions.\n",
    "predictions = tree_model.transform(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------------+-----+--------------+\n",
      "|            features|label|predictedLabel|\n",
      "+--------------------+-----+--------------+\n",
      "|[7.4,0.7,0.0,1.9,...|    5|             5|\n",
      "|[7.8,0.88,0.0,2.6...|    5|             5|\n",
      "|[7.8,0.76,0.04,2....|    5|             5|\n",
      "|[11.2,0.28,0.56,1...|    6|             6|\n",
      "|[7.4,0.7,0.0,1.9,...|    5|             5|\n",
      "+--------------------+-----+--------------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Select example rows to display.\n",
    "predictions.select(\"features\",\"label\",\"predictedLabel\").show(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predictions accuracy = 0.766729, Test Error = 0.233271\n"
     ]
    }
   ],
   "source": [
    "# Select (prediction, true label) and compute test error\n",
    "evaluator = MulticlassClassificationEvaluator(\n",
    "    labelCol=\"indexedLabel\", predictionCol=\"prediction\", metricName=\"accuracy\")\n",
    "accuracy = evaluator.evaluate(predictions)\n",
    "print(\"Predictions accuracy = %g, Test Error = %g\" % (accuracy,(1.0 - accuracy)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
